import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
import os

def setup(rank, world_size):
    """Join the default process group for this worker.

    Points the rendezvous at ``localhost:12355`` through the ``MASTER_ADDR`` /
    ``MASTER_PORT`` environment variables and initialises the group with the
    NCCL backend.

    Args:
        rank: Index of the calling process within the group.
        world_size: Total number of processes in the group.
    """
    os.environ.update(MASTER_ADDR='localhost', MASTER_PORT='12355')
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

class CustomResNet18(nn.Module):
    """torchvision ResNet-18 with its final ``fc`` layer swapped for a new head.

    The backbone is created with ``pretrained=True``; only the last linear
    layer is replaced so that the network emits ``num_classes`` logits.

    Args:
        num_classes: Output width of the replacement linear layer.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Reuse the input width of the stock classifier for the new head.
        head_in_features = backbone.fc.in_features
        backbone.fc = nn.Linear(head_in_features, num_classes)
        self.resnet = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped ResNet and return its output."""
        logits = self.resnet(x)
        return logits

class CustomCrossEntropyLoss(nn.Module):
    """Cross-entropy with hard targets for selected classes, soft targets otherwise.

    For a sample whose target is in ``given_labels`` the loss is the ordinary
    cross-entropy against the class index. For every other sample the loss is
    the cross-entropy against the probability vector
    ``(1 - u) * uniform + u * one_hot(target)``, where ``uniform`` puts
    ``1 / num_classes`` on every class. The per-sample losses are averaged.

    Args:
        given_labels: Iterable of integer class indices that keep hard targets.
        u: Weight of the one-hot component in the soft target.
        num_classes: Denominator used for the uniform component.
    """

    def __init__(self, given_labels, u, num_classes=100):
        super().__init__()
        self.given_labels = set(given_labels)
        self.ce = nn.CrossEntropyLoss(reduction='none')
        self.num_classes = num_classes
        self.u = u
        # Tensor copy of the label set so membership is decided by a single
        # vectorized op instead of one `.item()` call per sample. Registered
        # as a non-persistent buffer: it follows `.to(device)` but is kept
        # out of the state dict.
        self.register_buffer(
            '_given_label_tensor',
            torch.tensor(list(self.given_labels), dtype=torch.long),
            persistent=False,
        )

    def forward(self, outputs, targets):
        """Return the mean mixed loss for a batch.

        Args:
            outputs: Logits; indexed as ``(sample, class)`` by the scatter below.
            targets: Integer class index per sample.

        Returns:
            Scalar tensor: mean over samples of the selected per-sample loss.
        """
        # `.to` is a no-op when the buffer already lives with the targets; it
        # covers callers that never moved the criterion to the compute device.
        given = self._given_label_tensor.to(targets.device)
        mask = torch.isin(targets, given)

        onehots = torch.zeros_like(outputs)
        onehots.scatter_(1, targets.unsqueeze(1), 1)
        uniform_targets = torch.ones_like(outputs) / self.num_classes
        soft_targets = (1 - self.u) * uniform_targets + self.u * onehots

        ce_loss = self.ce(outputs, targets)
        soft_loss = self.ce(outputs, soft_targets)

        # Hard-target loss where the label is "given", soft-target loss elsewhere.
        loss = torch.where(mask, ce_loss, soft_loss)
        return loss.mean()

def train(rank, world_size, given_labels, u):
    """Train ``CustomResNet18`` on CIFAR-100 with DDP and save the weights on rank 0.

    Each process joins the process group, binds to GPU ``rank``, trains on its
    shard of the training set with ``CustomCrossEntropyLoss(given_labels, u)``
    and SGD (lr 0.1, momentum 0.9, weight decay 5e-4) under a cosine schedule.
    Rank 0 writes the unwrapped model's state dict to
    ``n{u}_cifar100_resnet18_model{given_labels[0]}.pth``.

    Args:
        rank: Process index; also used as the CUDA device index.
        world_size: Number of participating processes.
        given_labels: Class indices passed to the loss; the first one is also
            part of the checkpoint file name.
        u: One-hot mixing weight passed to the loss; also part of the
            checkpoint file name.
    """
    setup(rank, world_size)
    try:
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")

        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        ])

        # Let rank 0 fetch/extract the dataset first; the other ranks wait at
        # the barrier so they never download into the same directory at once.
        if rank != 0:
            dist.barrier()
        trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
        if rank == 0:
            dist.barrier()

        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset, num_replicas=world_size, rank=rank)
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, sampler=train_sampler)

        net = CustomResNet18().to(device)
        net = DDP(net, device_ids=[rank])
        criterion = CustomCrossEntropyLoss(given_labels, u).to(device)
        optimizer = optim.SGD(net.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

        # NOTE(review): the original comment said "train for 200 epochs"
        # (matching T_max above), but the loop runs a single epoch.
        for epoch in range(1):
            net.train()
            # Reseed the sampler so each epoch uses a different shuffle.
            train_sampler.set_epoch(epoch)
            for inputs, targets in trainloader:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

            scheduler.step()

            if rank == 0 and (epoch + 1) % 10 == 0:
                print(f'Epoch {epoch+1} completed')

        if rank == 0:
            # Save the wrapped module's weights so they load without DDP.
            torch.save(net.module.state_dict(), f'n{u}_cifar100_resnet18_model{given_labels[0]}.pth')
    finally:
        # Release the process group even when training raises.
        dist.destroy_process_group()

def test(given_labels, u):
    """Evaluate a saved checkpoint on the CIFAR-100 test split and print metrics.

    Loads ``n{u}_cifar100_resnet18_model{given_labels[0]}.pth`` into a fresh
    ``CustomResNet18`` and prints the overall accuracy, the accuracy restricted
    to samples whose target is in ``given_labels``, and the mean of the largest
    softmax probability over those samples.

    Args:
        given_labels: Class indices defining the "given" subset; the first one
            is also part of the checkpoint file name.
        u: Value embedded in the checkpoint file name.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])

    testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)

    net = CustomResNet18().to(device)
    # map_location lets a checkpoint written from a CUDA process load on
    # whichever device was selected above, including CPU-only hosts.
    state_dict = torch.load(f'n{u}_cifar100_resnet18_model{given_labels[0]}.pth', map_location=device)
    net.load_state_dict(state_dict)
    net.eval()

    # One tensor of the given labels so membership is a single vectorized op.
    given_label_tensor = torch.tensor(list(given_labels), dtype=torch.long, device=device)

    correct = 0
    total = 0
    given_correct = 0
    given_total = 0
    given_sum = 0

    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            _, predicted = outputs.max(1)
            hits = predicted.eq(targets)
            total += targets.size(0)
            correct += hits.sum().item()

            given_mask = torch.isin(targets, given_label_tensor)
            given_total += given_mask.sum().item()
            given_correct += (hits & given_mask).sum().item()

            # Largest softmax probability per sample, summed over the given subset.
            softmax_outputs = torch.nn.functional.softmax(outputs, dim=1)
            given_sum += softmax_outputs.max(1)[0][given_mask].sum().item()

    print(f'Overall Accuracy: {100 * correct / total:.2f}%')
    if given_total:
        print(f'Given Labels Accuracy: {100 * given_correct / given_total:.2f}%, given_total: {given_total}')
        print(f'Average of max softmax values for given labels: {given_sum/given_total:.2f}')
    else:
        # Avoid dividing by zero when no test sample carries a given label.
        print('No test samples matched the given labels; skipping given-label metrics.')

if __name__ == '__main__':
    # Sweep the mixing weight u over 0.1 .. 1.0 and, for each value, train and
    # evaluate one model per block of ten consecutive class labels.
    world_size = torch.cuda.device_count()
    if world_size < 1:
        # train() uses the NCCL backend and one GPU per process.
        raise RuntimeError('At least one CUDA device is required for training.')

    for step in range(1, 11):
        u = step / 10
        for i in range(10):
            # Example: the 10 given labels for this run (10*i .. 10*i + 9).
            given_labels = list(range(10 * i, 10 * i + 10))
            torch.multiprocessing.spawn(train, args=(world_size, given_labels, u), nprocs=world_size, join=True)
            # Pass the same u used for training: it is part of the checkpoint
            # file name, so a different value would look for a missing file.
            test(given_labels, u)